{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from diffusers import PixArtAlphaPipeline, DiffusionPipeline, StableVideoDiffusionPipeline\n",
    "from diffusers.utils import load_image, export_to_video, make_image_grid\n",
    "from cache_diffusion import cachify\n",
    "from cache_diffusion.utils import SVD_DEFAULT_CONFIG, SDXL_DEFAULT_CONFIG, PIXART_DEFAULT_CONFIG"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SDXL"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's load the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = DiffusionPipeline.from_pretrained(\n",
    "    \"stabilityai/stable-diffusion-xl-base-1.0\",\n",
    "    torch_dtype=torch.float16,\n",
    "    variant=\"fp16\",\n",
    "    use_safetensors=True,\n",
    ")\n",
    "pipe = pipe.to(\"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_inference_steps = 20\n",
    "prompt = \"beautiful lady, (freckles), big smile, blue eyes, short hair, dark makeup, hyperdetailed photography, soft light, head and shoulders portrait, cover\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our pipeline requires just a single API call to perform caching.\n",
    "\n",
    "Let's disable the caching and run the baseline model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cachify.prepare(pipe, num_inference_steps, SDXL_DEFAULT_CONFIG)\n",
    "cachify.disable(pipe)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = torch.Generator(device=\"cuda\").manual_seed(2946901)\n",
    "baseline_img_20_steps = pipe(prompt=prompt, num_inference_steps=num_inference_steps, generator=generator).images[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also reduce the number of steps to achieve similar latency as using cache diffusion. However, you will notice that the image quality is not as good."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = torch.Generator(device=\"cuda\").manual_seed(2946901)\n",
    "baseline_img_11_steps = pipe(prompt=prompt, num_inference_steps=11, generator=generator).images[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's enable the caching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cachify.enable(pipe)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = torch.Generator(device=\"cuda\").manual_seed(2946901)\n",
    "cache_img = pipe(prompt=prompt, num_inference_steps=num_inference_steps, generator=generator).images[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "make_image_grid([baseline_img_20_steps, cache_img, baseline_img_11_steps], 1, 3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PixArt-Alpha"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = PixArtAlphaPipeline.from_pretrained(\n",
    "    \"PixArt-alpha/PixArt-XL-2-1024-MS\", torch_dtype=torch.float16\n",
    ")\n",
    "pipe = pipe.to(\"cuda\")\n",
    "num_inference_steps = 30\n",
    "prompt = \"a small cactus with a happy face in the Sahara desert\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cachify.prepare(pipe, num_inference_steps, PIXART_DEFAULT_CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = torch.Generator(device=\"cuda\").manual_seed(2946901)\n",
    "pipe(prompt=prompt, generator=generator, num_inference_steps=num_inference_steps).images[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SVD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = StableVideoDiffusionPipeline.from_pretrained(\n",
    "    \"stabilityai/stable-video-diffusion-img2vid-xt\", torch_dtype=torch.float16, variant=\"fp16\"\n",
    ")\n",
    "pipe.enable_model_cpu_offload()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the conditioning image\n",
    "image = load_image(\n",
    "    \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png\"\n",
    ")\n",
    "image = image.resize((1024, 576))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = torch.manual_seed(42)\n",
    "num_inference_steps = 25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cachify.prepare(pipe, num_inference_steps, SVD_DEFAULT_CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]\n",
    "\n",
    "export_to_video(frames, \"generated.mp4\", fps=7)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ammo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.-1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
